{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bac3a6cd",
   "metadata": {},
   "source": [
    "- 原理上都适用于分类问题\n",
    "    - bce：二分类问题 （0-1）\n",
    "        - 概率的过程通过 sigmoid，\n",
    "        - 1d(scalar) 的 output `logits` => 2d `[1-sigmoid(z), sigmoid(z)]`\n",
    "    - cross_entropy：多分类问题（`0, 1, 2, ... , C-1`）\n",
    "        - 概率化的过程通过 softmax，\n",
    "        - log softmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "39602fc9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:25:23.813874Z",
     "start_time": "2023-11-26T10:25:23.806626Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "8a0d30b5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:46:14.026352Z",
     "start_time": "2023-11-26T10:46:14.014195Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.5031, -0.1608,  1.1623], requires_grad=True)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# prediction\n",
    "# number samples: 3\n",
    "input = torch.randn(3, requires_grad=True)\n",
    "input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "2d5f5258",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:17.246800Z",
     "start_time": "2023-11-26T10:54:17.239182Z"
    }
   },
   "outputs": [],
   "source": [
    "# true\n",
    "target = torch.empty(3, dtype=torch.long).random_(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "edad24b1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:18.515522Z",
     "start_time": "2023-11-26T10:54:18.504170Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 1])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96b9924e",
   "metadata": {},
   "source": [
    "## F.binary_cross_entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05a5c9a1",
   "metadata": {},
   "source": [
    "- F.binary_cross_entropy 的 target 可以是浮点数，未必非得是 0-1 (integer)\n",
    "    - VAE 在图像重构时的 loss，通过 bce loss 定义和计算的；"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "26b93e19",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:20.291711Z",
     "start_time": "2023-11-26T10:54:20.283599Z"
    }
   },
   "outputs": [],
   "source": [
    "loss = F.binary_cross_entropy(torch.sigmoid(input), target.to(torch.float))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "eb777c7e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:21.719561Z",
     "start_time": "2023-11-26T10:54:21.708177Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6750, grad_fn=<BinaryCrossEntropyBackward0>)"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "ad3b2515",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:22.647777Z",
     "start_time": "2023-11-26T10:54:22.633838Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6750, grad_fn=<NegBackward0>)"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "-torch.mean(target * torch.log(F.sigmoid(input)) + (1-target) * torch.log(1-F.sigmoid(input)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bc4a899",
   "metadata": {},
   "source": [
    "## nn.BCELoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "1b90daa6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:24.222542Z",
     "start_time": "2023-11-26T10:54:24.210501Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6750, grad_fn=<BinaryCrossEntropyBackward0>)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = nn.BCELoss()\n",
    "m(torch.sigmoid(input), target.to(torch.float))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93fb5059",
   "metadata": {},
   "source": [
    "## F.cross_entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7a410ac",
   "metadata": {},
   "source": [
    "- 1d => 2d\n",
    "- ignore_index=-100,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "cfdf9940",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:25.816943Z",
     "start_time": "2023-11-26T10:54:25.804104Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.9760, -0.4729],\n",
       "        [-0.6160, -0.7768],\n",
       "        [-1.4345, -0.2721]], grad_fn=<LogBackward0>)"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_2 = torch.stack([1-torch.sigmoid(input), torch.sigmoid(input)], dim=1)\n",
    "# log softmax\n",
    "input_2 = torch.log(input_2)\n",
    "input_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "2be7d192",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:27.122765Z",
     "start_time": "2023-11-26T10:54:27.112910Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 1])"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "6206b200",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:54:28.236956Z",
     "start_time": "2023-11-26T10:54:28.225731Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.6750, grad_fn=<NllLossBackward0>)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.nll_loss(input_2, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "c127e0f1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-11-26T10:58:39.352335Z",
     "start_time": "2023-11-26T10:58:39.342689Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6749666666666667"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "-((-0.9760 + (-0.7768)+(-0.2721))/3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
